import functools
import os
import pickle
import time

import yaml
import joblib
from keras.applications.vgg16 import VGG16, preprocess_input
from keras.preprocessing import image
from skimage.feature import hog
from skimage import color
import numpy as np
from PIL import Image


# Load the YAML configuration file once, at import time. The path is relative,
# so it is resolved against the current working directory.
# NOTE(review): yaml.Loader is PyYAML's full loader and can construct arbitrary
# Python objects from tagged YAML; consider yaml.safe_load if the config only
# holds plain scalars, lists and mappings.
with open('0_setting.yaml', 'r', encoding='utf-8') as f:
    setting = yaml.load(f, Loader=yaml.Loader)


def save_pickle(data, filename):
    """Pickle *data* into the file at *filename*, overwriting any existing file.

    Args:
        data: Any picklable object.
        filename: Destination path; opened in binary write mode.
    """
    with open(filename, 'wb') as out_file:
        pickle.dump(data, out_file)


def get(key):
    """Return the configuration value stored under *key*.

    Looks *key* up in the module-level ``setting`` object loaded from the YAML
    config. Whatever ``setting[key]`` raises propagates unchanged (e.g.
    ``KeyError`` for a missing key when the config is a mapping).
    """
    value = setting[key]
    return value


def preprocess_image(file_name, new_size):
    """Load an image, resize it, convert it to grayscale and return its HOG features.

    Args:
        file_name: Path (or file object) accepted by ``PIL.Image.open``.
        new_size: Target ``(width, height)`` passed to ``Image.resize``.

    Returns:
        The HOG descriptor computed with 9 orientations, 8x8-pixel cells,
        2x2-cell blocks and L2-Hys block normalisation.
    """
    # Image.open is lazy and keeps the file handle open; the context manager
    # closes it. convert() forces the pixel data to load, so the returned copy
    # stays usable after the file is closed.
    with Image.open(file_name) as img:
        # rgb2gray requires exactly 3 channels, so normalise grayscale ('L'),
        # palette ('P') and RGBA inputs to RGB first. RGB input is unchanged.
        rgb_img = img.convert('RGB')

    # Image.ANTIALIAS was removed in Pillow 10. LANCZOS is the same filter under
    # its canonical name, available in old and new Pillow releases.
    rgb_img = rgb_img.resize(new_size, Image.LANCZOS)

    # Convert to a grayscale array.
    gray = color.rgb2gray(np.array(rgb_img))

    # Extract the HOG features.
    features = hog(gray, orientations=9, pixels_per_cell=(8, 8),
                   cells_per_block=(2, 2), block_norm='L2-Hys')

    return features


def dump(obj, name, loc):
    """Serialize *obj* with joblib to *loc*, creating the parent directory if needed.

    Prints a progress line, then the saved file's size in MiB and the time taken.

    Args:
        obj: Object to persist with ``joblib.dump``.
        name: Human-readable label used only in the printed messages.
        loc: Destination file path.
    """
    file_dir = os.path.dirname(loc)
    # A bare filename has no directory part (file_dir == ''), and
    # os.makedirs('') raises FileNotFoundError, so only create a directory when
    # there is one. exist_ok=True avoids the check-then-create race when another
    # process creates the directory concurrently.
    if file_dir:
        os.makedirs(file_dir, exist_ok=True)
    # perf_counter is monotonic and meant for measuring durations.
    start = time.perf_counter()
    print(f"Saving {name} to {loc}...")
    joblib.dump(obj, loc)
    end = time.perf_counter()
    print(f"Saved! Location: {loc}, Size: {os.path.getsize(loc) / 1024 / 1024:.3f}M")
    print(f"Time taken: {end - start:.3f} seconds")


def load(name, loc):
    """Load an object from a joblib file and return it.

    Args:
        name: Human-readable label used only in the printed message.
        loc: Path of the file to read with ``joblib.load``.

    Note:
        joblib's on-disk format is pickle-based, so loading can execute
        arbitrary code. Only load files from trusted sources.
    """
    print(f"Loading {name} from {loc}...")
    return joblib.load(loc)


@functools.lru_cache(maxsize=1)
def _vgg16_model():
    """Build the ImageNet-weighted VGG16 model (no top) once and reuse it."""
    return VGG16(weights='imagenet', include_top=False)


def extract_vgg16_features(img_path):
    """Return flattened VGG16 (ImageNet weights, ``include_top=False``) features.

    The image is loaded at 224x224, converted to an array, given a leading
    batch axis and passed through ``preprocess_input`` before inference.

    Args:
        img_path: Path to the image file.

    Returns:
        numpy array of shape ``(1, n_features)``: the model output reshaped
        to one row per batch item. The batch holds a single image here.
    """
    # Building VGG16 loads its pretrained weights, which is expensive. Reuse one
    # cached instance instead of rebuilding the model on every call.
    model = _vgg16_model()
    img = image.load_img(img_path, target_size=(224, 224))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)  # add a leading batch axis of size 1
    x = preprocess_input(x)
    features = model.predict(x)
    # Keep the batch axis and flatten everything else into one row per image.
    features_flatten = features.reshape(features.shape[0], -1)
    return features_flatten


def extract_combined_features(img_path, new_size):
    """Concatenate an image's HOG features and VGG16 features into one row.

    Args:
        img_path: Path to the image file.
        new_size: ``(width, height)`` used by ``preprocess_image`` before HOG.

    Returns:
        2-D numpy array of shape ``(1, n_hog + n_vgg16)``.
    """
    hog_features = preprocess_image(img_path, new_size)
    vgg16_features = extract_vgg16_features(img_path)
    # skimage's hog() returns a flat 1-D vector by default, while the VGG16
    # features are a (1, n) row. np.concatenate(axis=1) needs both inputs to be
    # 2-D, so lift each to a single row first. The original mixed-rank call
    # always raised.
    hog_row = np.reshape(hog_features, (1, -1))
    vgg16_row = np.reshape(vgg16_features, (1, -1))
    combined_features = np.concatenate([hog_row, vgg16_row], axis=1)
    return combined_features